{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install langchain_google_vertexai langgraph langchain_community"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Natural Language to SQL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to demonstrate a Natural Language to SQL architecture, lets first create a sample database.\n",
    "\n",
    "We wull be using sqlite with the `langchain_community` `Database` interface.\n",
    "\n",
    "Let's create an empty database in the local file `database.sqlite`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.utilities import SQLDatabase\n",
    "\n",
    "db = SQLDatabase.from_uri(\"sqlite:///database.sqlite\")\n",
    "\n",
    "result = db.run(\n",
    "    \"\"\"\n",
    "      CREATE TABLE ITEMS\n",
    "      (name TEXT, type TEXT, color TEXT, season TEXT, price FLOAT, description TEXT)\n",
    "    \"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Bear in mind that if the file already exists, you can get an error.\n",
    "\n",
    "In order to populate the table, let's now read the data from `data.json`. This file contains item descriptions with some metadta, such as the type, color, season and prince of each item. \n",
    "\n",
    "Once the data is read, we insert each record in the database we just created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open(\"data.json\") as input_file:\n",
    "    items = json.load(input_file)\n",
    "\n",
    "sanitize = lambda string: string.replace(\"'\", \"\") # Remove apostrophe\n",
    "records = [\n",
    "    \", \".join(\n",
    "        f\"'{sanitize(value)}'\" if isinstance(value, str) else str(value)\n",
    "        for colname, value in item.items()\n",
    "    ) \n",
    "    for item in items\n",
    "]\n",
    "\n",
    "for record in records:\n",
    "    statement = f\"INSERT INTO ITEMS VALUES ({record})\"\n",
    "    try:\n",
    "        db.run(statement)\n",
    "    except Exception as err:\n",
    "        print(statement)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's define the tool the agent is going to use. The docstring of the function is the description that the model is going to use as the prompt, so is important for it to be descriptive. We include the schema of the database to make sure the agent knows which columns and datatypes exist."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.tools import tool\n",
    "\n",
    "@tool\n",
    "def execute_query(query: str) -> str:\n",
    "    \"\"\" Executes a query against the clothing items database.\n",
    "    The schema of the database is the following:\n",
    "    ```\n",
    "        CREATE TABLE ITEMS\n",
    "        (\n",
    "            name TEXT, -- Name of the item\n",
    "            type TEXT, -- Type of the item can be 't-shirt', 'pants' or 'coat'\n",
    "            color TEXT, -- Color of the item. Can be \"red\", \"blue\", \"black\" or \"white\"\n",
    "            season TEXT, -- Can be \"summer\", \"winter\" \"fall\" or \"spring\"\n",
    "            price FLOAT, -- Prize in dollars.\n",
    "            description TEXT -- Description of the item\n",
    "        )\n",
    "    \"\"\"\n",
    "    return db.run_no_throw(query)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we create the agent architecture, just as described in the book Chapter, and compile the graph to make it invokable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Literal\n",
    "from langchain_google_vertexai import ChatVertexAI\n",
    "from langchain_core.messages import AnyMessage\n",
    "from langgraph.graph import StateGraph, MessagesState, END\n",
    "from langgraph.prebuilt import ToolNode\n",
    "from langchain_core.messages import HumanMessage, SystemMessage\n",
    "\n",
    "\n",
    "graph = StateGraph(MessagesState)\n",
    "\n",
    "tools = [execute_query]\n",
    "\n",
    "generator = ChatVertexAI(\n",
    "  model_name=\"gemini-1.5-pro-001\", \n",
    "  temperature = 0\n",
    ").bind(tools=tools)\n",
    "\n",
    "def invoke_generator(state: MessagesState) -> None:\n",
    "    \"\"\" Represents the generator node.\n",
    "    \"\"\"\n",
    "    response = generator.invoke(state[\"messages\"])\n",
    "    state[\"messages\"].append(response)\n",
    "    \n",
    "def use_execute_query(state: MessagesState) -> Literal[\"query_tool\", END]:\n",
    "    \"\"\" Represents the use_retrieval conditional edge\n",
    "    \"\"\"\n",
    "    if not state[\"messages\"]:\n",
    "        return END\n",
    "    if state[\"messages\"][-1].tool_calls:\n",
    "        return \"query_tool\"\n",
    "    return END\n",
    "\n",
    "\n",
    "graph.add_node(\"generator\", invoke_generator)\n",
    "graph.add_node(\"query_tool\", ToolNode(tools))\n",
    "graph.set_entry_point(\"generator\")\n",
    "\n",
    "graph.add_conditional_edges(\"generator\", use_execute_query)\n",
    "graph.add_edge(\"query_tool\", \"generator\")\n",
    "\n",
    "\n",
    "agentic_sql = graph.compile()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can query the agent with the `stream` method and obserb how it generates the query and executes it in the dabase, retrieving the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "Can you show me the name and price of all coats?\n",
      "=================================\u001b[1m Tool Message \u001b[0m=================================\n",
      "Name: execute_query\n",
      "\n",
      "[('Midnight Magic Coat', 89.99), ('Autumn Hues Sweater', 49.99), ('Winter Wonderland Jacket', 79.99), ('Cozy Cabin Cardigan', 59.99), ('Fall Foliage Vest', 49.99), ('Snowy Peaks Parka', 99.99)]\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "Here are the coats with their prices:\n",
      "\n",
      "* Midnight Magic Coat - $89.99\n",
      "* Autumn Hues Sweater - $49.99\n",
      "* Winter Wonderland Jacket - $79.99\n",
      "* Cozy Cabin Cardigan - $59.99\n",
      "* Fall Foliage Vest - $49.99\n",
      "* Snowy Peaks Parka - $99.99\n"
     ]
    }
   ],
   "source": [
    "\n",
    "messages = [\n",
    "    SystemMessage(\n",
    "        \"\"\"\n",
    "        - You are a useful assistant that help users navigate a catalog of clothing items.\n",
    "        - You can retrieve clothing items from the catalog using a SQL query.\n",
    "        - Answer in natural language and format your output using paragraph or bullet points if\n",
    "          necessary.\n",
    "        \"\"\"\n",
    "    ),\n",
    "    HumanMessage(\"Can you show me the name and price of all coats?\")\n",
    "]\n",
    "\n",
    "stream_generator = agentic_sql.stream({\"messages\": messages}, stream_mode=\"values\")\n",
    "\n",
    "for state in stream_generator:\n",
    "    state[\"messages\"][-1].pretty_print()\n",
    "state[\"messages\"][-1].pretty_print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It can also execute calculations leveraging `SQL` capabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "What is the name and price in euros of the most expensive coat, given that 1 euro is currently 1.1 dollars?\n",
      "=================================\u001b[1m Tool Message \u001b[0m=================================\n",
      "Name: execute_query\n",
      "\n",
      "[('Snowy Peaks Parka', 109.989)]\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "The most expensive coat is the Snowy Peaks Parka and its price is 109.99 euros.\n"
     ]
    }
   ],
   "source": [
    "\n",
    "messages = [\n",
    "    SystemMessage(\n",
    "        \"\"\"\n",
    "        - You are a useful assistant that help users navigate a catalog of clothing items.\n",
    "        - You can retrieve clothing items from the catalog using a SQL query.\n",
    "        - Answer in natural language and format your output using paragraph or bullet points if\n",
    "          necessary.\n",
    "        \"\"\"\n",
    "    ),\n",
    "    HumanMessage(\"What is the name and price in euros of the most expensive coat, given that 1 euro is currently 1.1 dollars?\")\n",
    "]\n",
    "\n",
    "stream_generator = agentic_sql.stream({\"messages\": messages}, stream_mode=\"values\")\n",
    "\n",
    "for state in stream_generator:\n",
    "    state[\"messages\"][-1].pretty_print()\n",
    "state[\"messages\"][-1].pretty_print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
